"""Eval mAP@N metric from inference file."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags

import mean_average_precision_calculator as map_calculator
import numpy as np
import tensorflow as tf

# Command-line flags. Names and defaults are part of the script's interface;
# only the help strings describe them to the user.

# The features listed here are the context features that read_labels() parses.
flags.DEFINE_string(
    "eval_data_pattern", "",
    "File glob defining the evaluation dataset in tensorflow.SequenceExample "
    "format. The SequenceExamples are expected to have an 'id' bytes context "
    "feature as well as 'segment_labels' (int64), 'segment_start_times' "
    "(int64) and 'segment_scores' (float) context features.")
flags.DEFINE_string(
    "label_cache", "",
    "The path for the label cache file. Leave blank to disable caching.")
flags.DEFINE_string("submission_file", "",
                    "The segment submission file generated by inference.py.")
flags.DEFINE_integer(
    "top_n", 0,
    "Cap per-class predictions to a maximum of N. Use 0 for no capping.")

FLAGS = flags.FLAGS


class Labels(object):
  """Holder for ground-truth segment labels with simple text persistence.

  The wrapped mapping is keyed by (segment_id, class_id) tuples and stores the
  label score for that pair. Instances can be written to, and rebuilt from, a
  text file holding one "segment_id,class_id,score" record per line.
  """

  def __init__(self, labels):
    """Wrap an existing (segment_id, class_id) -> label_score mapping."""
    self._labels = labels

  @property
  def labels(self):
    """The underlying (segment_id, class_id) -> label_score mapping."""
    return self._labels

  def to_file(self, file_name):
    """Write every ground-truth entry to `file_name`, one record per line."""
    with tf.gfile.Open(file_name, "w") as out:
      for (seg_id, class_id), score in self._labels.items():
        out.write("%s,%s,%s\n" % (seg_id, class_id, score))

  @classmethod
  def from_file(cls, file_name):
    """Rebuild a Labels object from a file in the format `to_file` writes."""
    mapping = {}
    with tf.gfile.Open(file_name) as src:
      for raw in src:
        # Exactly three comma-separated fields are expected per record.
        seg_id, class_id, score = raw.strip().split(",")
        mapping[(seg_id, int(class_id))] = float(score)
    return cls(mapping)


def read_labels(data_pattern, cache_path=""):
  """Read rated segment labels from TFRecords, optionally through a cache.

  If `cache_path` is given and the file exists, the labels are loaded from it
  and the TFRecords are not touched. Otherwise every SequenceExample matched
  by `data_pattern` is parsed and, when `cache_path` is given, the result is
  written there for later runs.

  Args:
    data_pattern: the data pattern to the TFRecords.
    cache_path: the cache path for the label file. Empty disables caching.

  Returns:
    a Labels object whose mapping is keyed by
    ("<video_id>:<segment_start_time>", class_id) and holds the segment score.
  """
  if cache_path and tf.gfile.Exists(cache_path):
    tf.logging.info("Reading cached labels from %s..." % cache_path)
    return Labels.from_file(cache_path)
  # Eager mode is needed to iterate the dataset and call .numpy() below.
  tf.enable_eager_execution()
  data_paths = tf.gfile.Glob(data_pattern)
  ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50)
  context_features = {
      "id": tf.FixedLenFeature([], tf.string),
      "segment_labels": tf.VarLenFeature(tf.int64),
      "segment_start_times": tf.VarLenFeature(tf.int64),
      "segment_scores": tf.VarLenFeature(tf.float32)
  }

  def _parse_se_func(sequence_example):
    """Parse only the context features of one serialized SequenceExample."""
    return tf.parse_single_sequence_example(
        sequence_example, context_features=context_features)

  ds = ds.map(_parse_se_func)
  rated_labels = {}
  tf.logging.info("Reading labels from TFRecords...")
  last_batch = 0
  batch_size = 5000
  for cxt_feature_val, _ in ds:
    # .numpy() on a string tensor yields bytes. Decode it so that "%s" does
    # not render as "b'...'" under Python 3, which would make every segment
    # key differ from the plain-text IDs found in the submission file.
    video_id = tf.compat.as_text(cxt_feature_val["id"].numpy())
    segment_labels = cxt_feature_val["segment_labels"].values.numpy()
    segment_start_times = cxt_feature_val["segment_start_times"].values.numpy()
    segment_scores = cxt_feature_val["segment_scores"].values.numpy()
    for label, start_time, score in zip(segment_labels, segment_start_times,
                                        segment_scores):
      # Store plain Python int/float so the mapping has the same types
      # whether it was built here or reloaded through Labels.from_file.
      segment_id = "%s:%d" % (video_id, start_time)
      rated_labels[(segment_id, int(label))] = float(score)
    # Emit a progress line each time another `batch_size` entries accumulate.
    batch_id = len(rated_labels) // batch_size
    if batch_id != last_batch:
      tf.logging.info("%d examples processed.", len(rated_labels))
      last_batch = batch_id
  tf.logging.info("Finish reading labels from TFRecords...")
  labels_obj = Labels(rated_labels)
  if cache_path:
    tf.logging.info("Caching labels to %s..." % cache_path)
    labels_obj.to_file(cache_path)
  return labels_obj


def read_segment_predictions(file_path, labels, top_n=None):
  """Read segment predictions.

  Each data line is expected to look like
  "<label_id>,<segment_id> <segment_id> ...", with segment IDs ordered from
  most to least confident. Predictions are first capped to `top_n` and then
  filtered down to the (segment_id, label_id) pairs present in `labels`.

  Args:
    file_path: the submission file path.
    labels: a Labels object containing the eval labels.
    top_n: the per-class prediction cap; None or 0 disables capping.

  Returns:
    a dict mapping each int label_id to its list of retained segment IDs, in
    the order they appear in the file.

  Raises:
    ValueError: if a line does not contain exactly one comma, or if a label
      field is not an integer (other than on the first line, see below).
  """
  rated = labels.labels
  cls_preds = {}  # A label_id to pred list mapping.
  with tf.gfile.Open(file_path) as fobj:
    tf.logging.info("Reading predictions from %s..." % file_path)
    for line_no, line in enumerate(fobj):
      # Strip the line terminator: otherwise the last segment ID on every
      # line carries a trailing "\n" and can never match a rated segment.
      line = line.strip()
      if not line:
        continue
      label_field, pred_ids_val = line.split(",")
      try:
        label_id = int(label_field)
      except ValueError:
        if line_no == 0:
          # Tolerate a non-numeric header row (e.g. "Class,Segments") on the
          # first line instead of failing on it.
          continue
        raise
      # Whitespace split also ignores repeated spaces between IDs.
      pred_ids = pred_ids_val.split()
      if top_n:
        pred_ids = pred_ids[:top_n]
      cls_preds[label_id] = [
          pred_id for pred_id in pred_ids if (pred_id, label_id) in rated
      ]
      if len(cls_preds) % 50 == 0:
        tf.logging.info("Processed %d classes..." % len(cls_preds))
    tf.logging.info("Finish reading predictions.")
  return cls_preds


def main(unused_argv):
  """Entry function of the script.

  Loads the rated eval labels and the submission file, turns each class's
  ranked segment list into descending pseudo-scores, and logs the mean of
  the per-class values returned by the mAP calculator.

  Args:
    unused_argv: positional command-line arguments passed by absl; ignored.

  Raises:
    ValueError: if --submission_file is not set.
  """
  if not FLAGS.submission_file:
    raise ValueError("You must input submission file.")
  eval_labels = read_labels(
      FLAGS.eval_data_pattern, cache_path=FLAGS.label_cache)
  tf.logging.info("Total rated segments: %d." % len(eval_labels.labels))

  # Number of positively rated segments per class, over ALL rated segments
  # (not only the predicted ones).
  positive_counter = {}
  for (_, label_id), score in eval_labels.labels.items():
    if score > 0:
      positive_counter[label_id] = positive_counter.get(label_id, 0) + 1

  seg_preds = read_segment_predictions(
      FLAGS.submission_file, eval_labels, top_n=FLAGS.top_n)
  map_cal = map_calculator.MeanAveragePrecisionCalculator(len(seg_preds))
  seg_labels = []
  seg_scored_preds = []
  num_positives = []
  for label_id in sorted(seg_preds):
    class_preds = seg_preds[label_id]
    seg_labels.append(
        [eval_labels.labels[(pred, label_id)] for pred in class_preds])
    # The submission only carries a ranking, so synthesize strictly
    # descending scores 1.0, (n-1)/n, ..., 1/n that preserve that order.
    # An empty prediction list yields an empty score list.
    num_preds = len(class_preds)
    seg_scored_preds.append(
        [float(rank) / num_preds for rank in range(num_preds, 0, -1)])
    # A submitted class may have no positive rated segment at all; count it
    # as zero positives rather than raising KeyError.
    # NOTE(review): assumes the calculator accepts 0 positives — confirm.
    num_positives.append(positive_counter.get(label_id, 0))
  map_cal.accumulate(seg_scored_preds, seg_labels, num_positives)
  map_at_n = np.mean(map_cal.peek_map_at_n())
  tf.logging.info("Num classes: %d | mAP@%d: %.6f" %
                  (len(seg_preds), FLAGS.top_n, map_at_n))


# Run main() through absl's app runner only when executed as a script.
if __name__ == "__main__":
  app.run(main)
